--- ../../swvox/swvox/csrc/rt_kernel.cu	2024-02-14 05:16:57.537407540 -0500
+++ ../google/swvox/swvox/csrc/rt_kernel.cu	2024-02-13 10:46:24.584645005 -0500
@@ -28,6 +28,8 @@
 #include <vector>
 #include "common.cuh"
 #include "data_spec_packed.cuh"
+#include "include/data_spec.hpp"
+#include "wavelets.cuh"
 
 namespace {
 
@@ -39,12 +41,14 @@
     cudaGetDeviceProperties(&dev_prop, 0);
     const int n_cores = get_sp_cores(dev_prop);
     // Optimize number of CUDA threads per block
+    // Halved w.r.t. upstream (256/512/1024): on a V100 the larger blocks failed with
+    // "too many resources requested for launch" in swvox.volume_render_backward.
     if (n_cores < 2048) {
-        cuda_n_threads = 256;
-    } if (n_cores < 8192) {
-        cuda_n_threads = 512;
+        cuda_n_threads = 128;
+    } else if (n_cores < 8192) {  // `else` added: without it the first branch was always overwritten
+        cuda_n_threads = 256;
     } else {
-        cuda_n_threads = 1024;
+        cuda_n_threads = 512;
     }
 }
 
@@ -117,6 +121,11 @@
     const scalar_t* __restrict__ dir,
     scalar_t* __restrict__ out) {
     switch(format) {
+        case FORMAT_RGBA:  // only out[0] is written: a single constant basis value
+        {
+            out[0] = C0; // = 1.0; changed to match plenoxels (C0 presumably the degree-0 SH constant)
+        }  // RGBA
+        break;
         case FORMAT_ASG:
             {
                 // UNTESTED ASG
@@ -181,11 +190,141 @@
             break;
 
         default:
+            out[0] = C0; // = 1.0; changed to match plenoxels
+            break;
+    }  // switch
+}
+
+// Fills wavelet_basis (row i = octree path level i, n_basis values per row, for
+// i < valid_nodes) and sets *n_wavelet_basis = n_basis (left at 0 for unknown formats).
+// CONSTANT: one basis function equal to 1 at every level. Other formats: levels below
+// lowpass_depth are zeroed, level lowpass_depth is [1, 0, ..., 0], and deeper levels
+// evaluate the wavelet at all_relative_pos[3 * i ..] or, with eval_wavelet_integral,
+// average it over N_WAVELET_EVALUATIONS steps along dir (t_subcube rescaled per level).
+template <typename scalar_t>
+__device__ __inline__ void maybe_precalc_wavelet(
+    const int format,
+    const int lowpass_depth,
+    const bool eval_wavelet_integral,
+    const scalar_t* __restrict__ all_relative_pos,
+    const scalar_t* __restrict__ cube_sz,
+    const scalar_t* __restrict__ dir,
+    const int valid_nodes,
+    const scalar_t t_subcube,  // by value: __restrict__ only applies to pointers
+    int* n_wavelet_basis,
+    scalar_t* __restrict__ wavelet_basis) {
+    // Default for formats not handled below, so callers never read an unset count.
+    *n_wavelet_basis = 0;
+    switch(format) {
+      case HAAR: case TRILINEAR: case SIDE: case DB2: case BIOR22:
+            {
+                // For haar, we could use dda to evaluate exactly the quadrants where it is -1/+1, but we keep it general with N queries to the wavelet function, as that allows switching wavelets easily
+                int n_basis = 0;
+                if (format == HAAR){
+                    n_basis = 7;
+                }else if (format == TRILINEAR){
+                    n_basis = 8;
+                }else if (format == SIDE){
+                    n_basis = 2;
+                } else if (format == DB2){
+                    n_basis = 8;
+                } else if (format == BIOR22){
+                    n_basis = 8;
+                }
+                *n_wavelet_basis = n_basis;
+
+                if (lowpass_depth < -1){
+                    printf("ERROR, lowpass_depth should be >= -1 \n");
+                    return;  // wavelet_basis is left unfilled
+                } 
+                if (lowpass_depth >= valid_nodes){
+                    printf("ERROR, lowpass_depth (%d) should be < valid_nodes (%d). This means that the octree tree is not fully instantiated up to the lowpass depth!\n", lowpass_depth, valid_nodes);
+                    return;  // wavelet_basis is left unfilled
+                }
+                if (lowpass_depth >= 0){
+                    // set all coefficients to 0 up to the lowpass depth, and the first one of the lowpass_depth level to 1
+                    for (int i=0; i < lowpass_depth; i++){
+                        for (int j=0; j < n_basis; j++){
+                            wavelet_basis[i * n_basis + j] = 0;
+                        }
+                    }
+                    wavelet_basis[lowpass_depth * n_basis] = 1;
+                    for (int j=1; j < n_basis; j++){
+                        wavelet_basis[lowpass_depth * n_basis + j] = 0;
+                    }
+                }
+                for (int i=lowpass_depth + 1; i < valid_nodes; i++){
+                    // TODO: do it with ray tracing along the ray. for now we only evaluate at pos, which is at the face of the wavelet.
+                    // The "correct" implementation should integrate along the ray from the start to the end of the cube
+                    if (not eval_wavelet_integral){
+                        //printf("evaluating haar\n");
+                        if (format == HAAR){
+                            evaluate_haar(&all_relative_pos[i * 3], &wavelet_basis[i * n_basis]);
+                        } else if (format == TRILINEAR ){
+                            evaluate_trilinear(&all_relative_pos[i * 3], &wavelet_basis[i * n_basis]);
+                        } else if (format == SIDE ){
+                            evaluate_side(&all_relative_pos[i * 3], &wavelet_basis[i * n_basis]);
+                        } else if (format == DB2 ){
+                            evaluate_db2(&all_relative_pos[i * 3], &wavelet_basis[i * n_basis]);
+                        } else if (format == BIOR22 ){
+                            evaluate_bior22(&all_relative_pos[i * 3], &wavelet_basis[i * n_basis]);
+                        }
+
+                    }
+                    else{
+                        // The relative pos are scaled relative to the node (see query_path_from_root).
+                        // When computing the integral, t_subcube (the distance along the ray at the finest scale) must be rescaled to each coarser scale (i.e. *0.5 every time we go a scale up)
+                        // TODO: do the ray tracing if necessary
+                        // cube_sz is the resolution (a 2x smaller cube has 2x cube_sz); invariant over k, so computed once per level
+                        const scalar_t relative_size = cube_sz[i] / cube_sz[valid_nodes -1];
+                        if (relative_size > 1.f){
+                            printf("Relative size is bigger than 1.0, there's an error!");
+                        }
+                        for (int k=0; k < N_WAVELET_EVALUATIONS; k++){
+                            scalar_t current_position[3];
+                            current_position[0] = all_relative_pos[i * 3 + 0] + dir[0] * k / N_WAVELET_EVALUATIONS * t_subcube * relative_size;
+                            current_position[1] = all_relative_pos[i * 3 + 1] + dir[1] * k / N_WAVELET_EVALUATIONS * t_subcube * relative_size;
+                            current_position[2] = all_relative_pos[i * 3 + 2] + dir[2] * k / N_WAVELET_EVALUATIONS * t_subcube * relative_size;
+
+                            clamp_coord<scalar_t>(current_position);
+
+                            // on the first k we assign the wavelet values (=), on the next ones we just add to compute the (+=) sum, and then divide over samples to get the mean
+                            bool add = k != 0;
+                            if (format == HAAR){
+                                evaluate_haar(current_position, &wavelet_basis[i * n_basis], add);
+                            } else if (format == TRILINEAR ){
+                                evaluate_trilinear(current_position, &wavelet_basis[i * n_basis], add);
+                            } else if (format == SIDE ){
+                                evaluate_side(current_position, &wavelet_basis[i * n_basis], add);
+                            } else if (format == DB2 ){
+                                evaluate_db2(current_position, &wavelet_basis[i * n_basis], add);
+                            } else if (format == BIOR22){
+                                evaluate_bior22(current_position, &wavelet_basis[i * n_basis], add);
+                            }
+                        }
+                        // average
+                        for (int j=0; j < n_basis; j++){
+                            wavelet_basis[i * n_basis + j] /= N_WAVELET_EVALUATIONS;
+                        }
+                    }
+                }
+            }
+            break;
+        case CONSTANT:
+            {
+                *n_wavelet_basis = 1;
+                for (int i=0; i < valid_nodes; i++){
+                    // this is just the "constant" wavelet for now
+                    wavelet_basis[i] = 1;
+                }
+            }  // "CONSTANT" wavelet
+            break;
+        default:
             // Do nothing
             break;
     }  // switch
 }
 
 template <typename scalar_t>
 __device__ __inline__ scalar_t _get_delta_scale(
     const scalar_t* __restrict__ scaling,
@@ -219,11 +358,12 @@
     }
 }
 
+// Forward pass for a single ray: marches samples through the octree and accumulates into out.
 template <typename scalar_t>
 __device__ __inline__ void trace_ray(
         PackedTreeSpec<scalar_t>& __restrict__ tree,
-        SingleRaySpec<scalar_t> ray,
-        RenderOptions& __restrict__ opt,
+        SingleRaySpec2<scalar_t> ray,  // swvox ray spec; also carries random_offset
+        RenderOptionsSwvox& __restrict__ opt,  // swvox render options (wavelet / accumulate flags)
         torch::TensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, int32_t> out) {
     const scalar_t delta_scale = _get_delta_scale(tree.scaling, ray.dir);
 
@@ -232,7 +372,26 @@
     const int tree_N = tree.child.size(1);
     const int data_dim = tree.data.size(4);
     const int out_data_dim = out.size(0);
+    if (out_data_dim != 3){
+        printf("ERROR: out_data_dim should be 3 and is %d", out_data_dim);
+        return;
+    }
+
+    const bool accumulate = opt.accumulate;
+    const bool piecewise_linear = opt.piecewise_linear;
+    const bool linear_color = opt.linear_color;
+    const bool wavelet_sigma = opt.wavelet_sigma;
+    // Assuming tree.max_depth is a pointer to const int32_t
+    const int32_t* max_depth_ptr = tree.max_depth;
 
+    // tree.max_depth bounds the octree path traversal below; never continue
+    // with an uninitialized depth.
+    if (max_depth_ptr == nullptr) {
+        printf("ERROR: tree.max_depth is null\n");
+        return;
+    }
+    int32_t tree_max_depth = *max_depth_ptr;
+
 #pragma unroll
     for (int i = 0; i < 3; ++i) {
         invdir[i] = 1.0 / (ray.dir[i] + 1e-9);
@@ -249,73 +408,245 @@
         for (int j = 0; j < out_data_dim; ++j) {
             out[j] = 0.f;
         }
+        // Up to 25 is implemented but we only use up to 16
+        // Set to {15, 15, 15, 15} to disable
+        // k is <= this number 
+        // int num_sh[4] = {0, 3, 8, 15};
+        // int num_sh[4] = {15, 15, 15, 15};                
+
         scalar_t pos[3];
         scalar_t basis_fn[25];
-        maybe_precalc_basis<scalar_t>(opt.format, opt.basis_dim,
-                tree.extra_data, ray.vdir, basis_fn);
+        maybe_precalc_basis<scalar_t>(opt.format, 
+                                      opt.basis_dim, 
+                                      tree.extra_data, ray.vdir, basis_fn);
 
         scalar_t light_intensity = 1.f;
         scalar_t t = tmin;
-        scalar_t cube_sz;
+
+        // See Mip-NeRF: E.2. Activation Functions
         const scalar_t d_rgb_pad = 1 + 2 * opt.rgb_padding;
+        scalar_t prev_sigma = 0.;
+        scalar_t prev_color[3] = {0., 0., 0.};
+        // tmax - t < sqrt(3)
         while (t < tmax) {
             for (int j = 0; j < 3; ++j) {
-                pos[j] = ray.origin[j] + t * ray.dir[j];
+                pos[j] = ray.origin[j] + t * ray.dir[j] + ray.random_offset[j];
+                // TODO(machc): clamp?
             }
 
-            int64_t node_id;
-            scalar_t* tree_val = query_single_from_root<scalar_t>(tree.data, tree.child,
-                        pos, &cube_sz, tree.weight_accum != nullptr ? &node_id : nullptr);
+            scalar_t* data_ptrs[MAX_TREE_DEPTH + 1];
+            int64_t node_ids[MAX_TREE_DEPTH + 1];
+            scalar_t all_cube_sz[MAX_TREE_DEPTH + 1];
+            // relative pos at each voxel, between [0,1],
+            // e.g. absolute position 0.8 with three voxel subdivisions will return 
+            // [(0.8 - 0.5)/0.5, (0.8 - 0.5 - 0.25)/0.25, (0.8 - 0.5 - 0.25 - 0.0)/0.125] =
+            // [0.6, 0.2, 0.4]
+            scalar_t all_relative_pos[(MAX_TREE_DEPTH + 1) * 3];
+
+            // wavelet evaluations: up to MAX_WAVELET_SIZE basis functions per tree level
+            scalar_t wavelet_evals[(MAX_TREE_DEPTH + 1) * MAX_WAVELET_SIZE];
+
+            int valid_nodes = 0;
+            if (accumulate){
+                // Fixed-size stack arrays (above) are used instead of per-sample `new`, which
+                // caused illegal memory accesses; MAX_TREE_DEPTH therefore caps the tree depth.
+                if (tree_max_depth > MAX_TREE_DEPTH){
+                    printf("ERROR: tree max depth (%d) exceeds MAX_TREE_DEPTH (%d)\n", (int)tree_max_depth, (int)MAX_TREE_DEPTH);
+                    return;
+                }
+
+                query_path_from_root<scalar_t>(tree.data,
+                                               tree.child,
+                                               pos, 
+                                               all_relative_pos,
+                                               all_cube_sz, 
+                                               data_ptrs,
+                                               tree_max_depth,
+                                               &node_ids[0]);
+
+                // Count the levels of the root-to-leaf path that hold a valid node id.
+                while (valid_nodes < tree_max_depth + 1){
+                    if (node_ids[valid_nodes] == INVALID_NODE_ID){
+                        if (valid_nodes == 0){
+                            printf("ERROR, some node id is invalid as it was already -1 on the root!");  
+                            return;  // no valid leaf: node_ids[valid_nodes - 1] below would be out of bounds
+                        }
+                        break;
+                    }
+                    valid_nodes += 1;
+                }
+                if (valid_nodes == MAX_TREE_DEPTH + 1){
+                    printf("Error, reached out depth == MAX_TREE_DEPTH, there's a bug or your tree is too big!!!");
+                    return;
+                }
+            }
+            else{  // !accumulate
+                if (opt.accumulate_sigma) {
+                    printf("ERROR: accumulate_sigma is set to true, but accumulate is false!");
+                }
+                // just use the leaf, and use fast querying
+                data_ptrs[0] = query_single_from_root<scalar_t>(tree.data, 
+                    tree.child,
+                    pos, 
+                    all_cube_sz, 
+                    &node_ids[0]);
+                // pos now holds the leaf-relative position; expose it as level 0 for the wavelets.
+                for (int j = 0; j < 3; ++j) all_relative_pos[j] = pos[j];
+                valid_nodes = 1;
+            }
+
+            int64_t leaf_node_id = node_ids[valid_nodes -1];
+            scalar_t* leaf_val = data_ptrs[valid_nodes - 1];
+            scalar_t leaf_cube_sz = all_cube_sz[valid_nodes -1];
 
             scalar_t att;
             scalar_t subcube_tmin, subcube_tmax;
+            // pos was modified in place by query_single/path_from_root to the relative pos within the finest (leaf) voxel
             _dda_unit(pos, invdir, &subcube_tmin, &subcube_tmax);
 
-            const scalar_t t_subcube = (subcube_tmax - subcube_tmin) / cube_sz;
+            scalar_t t_subcube = (subcube_tmax - subcube_tmin) / leaf_cube_sz;
             const scalar_t delta_t = t_subcube + opt.step_size;
-            scalar_t sigma = tree_val[data_dim - 1];
+
+            // Evaluate the wavelet basis of every path level for this sample.
+            // NOTE(review): t_subcube appears to be in world-t units (divided by leaf_cube_sz) and ray.vdir
+            // is the step direction, while all_relative_pos is cell-local; confirm the units used when
+            // opt.eval_wavelet_integral is set.
+            int n_wavelet_basis = 0;  // stays 0 if opt.wavelet_type is not handled
+            maybe_precalc_wavelet<scalar_t>(
+                opt.wavelet_type,
+                opt.lowpass_depth, 
+                opt.eval_wavelet_integral,
+                all_relative_pos, all_cube_sz, ray.vdir, 
+                valid_nodes, 
+                t_subcube, 
+                &n_wavelet_basis, 
+                wavelet_evals);
+
+            // the number of colors is n_channels * n_wavelet_basis * opt.basis_dim (sh);
+            const int color_dim = out_data_dim * n_wavelet_basis * opt.basis_dim;
+
+            if (!wavelet_sigma){
+                if (color_dim + 1 != data_dim){
+                    printf("ERROR: Data dim (%d) should equal color_dim + 1 (%d) (for sigma)\n", data_dim, color_dim + 1);
+                }
+            } else if (color_dim + n_wavelet_basis != data_dim){
+                printf("ERROR: Data dim (%d) should equal color_dim + n_wavelet_basis (%d) (for sigma)\n", data_dim, color_dim + n_wavelet_basis);
+            }
+
+
+            scalar_t sigma = 0.f;  // also the starting value of the wavelet_sigma sums below
+            if (opt.accumulate_sigma){
+                // Sum sigma over the levels from lowpass_depth (clamped to the root) to the leaf.
+                if (opt.lowpass_depth < 0){
+                    printf("ERROR! Lowpass depth < 0!");
+                }
+                for (int i = (opt.lowpass_depth < 0) ? 0 : opt.lowpass_depth; i < valid_nodes; i++){
+                    if (!wavelet_sigma){
+                        sigma += data_ptrs[i][data_dim - 1];
+                    } else {
+                        // accumulate over wavelets, same as color but without sh
+                        for (int j = 0; j < n_wavelet_basis; ++j ){
+                            sigma += data_ptrs[i][color_dim + j] * wavelet_evals[i * n_wavelet_basis + j];
+                        }
+                    }
+                }
+            } else {  // !opt.accumulate_sigma
+                if (!wavelet_sigma){
+                    sigma = leaf_val[data_dim - 1];
+                } else {
+                    for (int j = 0; j < n_wavelet_basis; ++j ){
+                        // use the wavelet evaluation of the leaf, which is valid_nodes - 1.
+                        sigma += leaf_val[color_dim + j] * wavelet_evals[(valid_nodes - 1) * n_wavelet_basis + j];
+                    }
+                }
+            }
+
+            // For piecewise linear, use (sigma + prev_sigma) / 2
+            if (piecewise_linear){
+                if (t != tmin) {
+                    scalar_t new_sigma = sigma;
+                    sigma = (sigma + prev_sigma) / 2.0;
+                    prev_sigma = new_sigma;
+                }
+                else prev_sigma = sigma;
+            }
             if (opt.density_softplus) sigma = _SOFTPLUS_M1(sigma);
             if (sigma > opt.sigma_thresh) {
+                // attenuation.
                 att = expf(-delta_t * delta_scale * sigma);
+                // light_intensity is Ti.
                 const scalar_t weight = light_intensity * (1.f - att);
 
-                if (opt.format != FORMAT_RGBA) {
-                    for (int t = 0; t < out_data_dim; ++ t) {
-                        int off = t * opt.basis_dim;
-                        scalar_t tmp = 0.0;
-                        for (int i = opt.min_comp; i <= opt.max_comp; ++i) {
-                            tmp += basis_fn[i] * tree_val[off + i];
+                if (opt.render_distance){
+                    out[0] += weight * t;
+                }
+                else{
+                    for (int l = 0; l < out_data_dim; ++l) {
+                        // color channel l: coefficients are laid out as [channel][wavelet j][SH k]
+                        // (opt.basis_dim SH coefficients per wavelet basis function)
+                        int off = l * n_wavelet_basis * opt.basis_dim;
+                        scalar_t color = linear_color ? 0.5f : 0.0f; // To match plenoxels when linear, it should be biased to 0.5
+                        for (int i = 0; i < valid_nodes; i++){
+                            // n_wavelet_basis components per level (e.g. 7 for HAAR)
+                            for (int j = 0; j < n_wavelet_basis; ++j ){
+                                // SH components opt.min_comp..opt.max_comp
+                                for (int k = opt.min_comp; k <= opt.max_comp; ++k) {
+                                // for (int k = opt.min_comp; k <= num_sh[min(i / 2, 3)]; ++k) {
+                                    // coefficient for each SH/wavelet component
+                                    color += data_ptrs[i][off + opt.basis_dim * j + k] *
+                                           wavelet_evals[i * n_wavelet_basis + j] *
+                                           // SH evaluated at given direction
+                                           basis_fn[k];
+                                }
+                            }
+                        }
+                        if (piecewise_linear){
+                            // Smooth color.
+                            if (t != tmin) {
+                                scalar_t new_color = color;
+                                color = (prev_color[l] + color) / 2.0;
+                                prev_color[l] = new_color;
+                            }
+                            else prev_color[l] = color;
+                        }
+                        if (linear_color){
+                            if (color > 0.f) { // ReLU
+                                out[l] += weight * color;
+                            }
+                        } else {
+                            out[l] += weight * (_SIGMOID(color) * d_rgb_pad - opt.rgb_padding);
                         }
-                        out[t] += weight * (_SIGMOID(tmp) * d_rgb_pad - opt.rgb_padding);
-                    }
-                } else {
-                    for (int j = 0; j < out_data_dim; ++j) {
-                        out[j] += weight * (_SIGMOID(tree_val[j]) * d_rgb_pad - opt.rgb_padding);
                     }
                 }
+                
+
                 light_intensity *= att;
 
                 if (tree.weight_accum != nullptr) {
                     if (tree.weight_accum_max) {
-                        atomicMax(&tree.weight_accum[node_id], weight);
+                        atomicMax(&tree.weight_accum[leaf_node_id], weight);
                     } else {
-                        atomicAdd(&tree.weight_accum[node_id], weight);
+                        atomicAdd(&tree.weight_accum[leaf_node_id], weight);
                     }
                 }
 
                 if (light_intensity <= opt.stop_thresh) {
                     // Full opacity, stop
-                    scalar_t scale = 1.0 / (1.0 - light_intensity);
-                    for (int j = 0; j < out_data_dim; ++j) {
-                        out[j] *= scale;
+                    if (not opt.render_distance){
+                        scalar_t scale = 1.0 / (1.0 - light_intensity);
+                        for (int j = 0; j < out_data_dim; ++j) {
+                            out[j] *= scale;
+                        }
                     }
                     return;
                 }
             }
             t += delta_t;
         }
-        for (int j = 0; j < out_data_dim; ++j) {
-            out[j] += light_intensity * opt.background_brightness;
+        if (not opt.render_distance){
+            for (int j = 0; j < out_data_dim; ++j) {
+                out[j] += light_intensity * opt.background_brightness;
+            }
         }
     }
 }
@@ -325,8 +656,8 @@
     PackedTreeSpec<scalar_t>& __restrict__ tree,
     const torch::TensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, int32_t>
         grad_output,
-        SingleRaySpec<scalar_t> ray,
-        RenderOptions& __restrict__ opt,
+        SingleRaySpec2<scalar_t> ray,  // swvox ray spec; also carries random_offset
+        RenderOptionsSwvox& __restrict__ opt,  // swvox render options (wavelet / accumulate flags)
     torch::PackedTensorAccessor64<scalar_t, 5, torch::RestrictPtrTraits>
         grad_data_out) {
     const scalar_t delta_scale = _get_delta_scale(tree.scaling, ray.dir);
@@ -337,6 +668,29 @@
     const int data_dim = tree.data.size(4);
     const int out_data_dim = grad_output.size(0);
 
+    const bool linear_color = opt.linear_color;
+    const bool accumulate = opt.accumulate;
+    const bool piecewise_linear = opt.piecewise_linear;
+    bool wavelet_sigma = opt.wavelet_sigma;
+
+    const float sigma_penalty = opt.sigma_penalty;
+    const bool absolute_values = opt.backward_absolute_values;
+
+    if (opt.render_distance){
+        printf("Cannot backprop if render_distance is active! Render distance is only for visualization purposes.\n");
+        return;
+    }
+
+    const int32_t* max_depth_ptr = tree.max_depth; // Assuming tree.max_depth is a pointer to const int32_t
+
+    // tree.max_depth bounds the octree path traversal; never continue with an
+    // uninitialized depth.
+    if (max_depth_ptr == nullptr) {
+        printf("ERROR: tree.max_depth is null\n");
+        return;
+    }
+    int32_t tree_max_depth = *max_depth_ptr;
+
 #pragma unroll
     for (int i = 0; i < 3; ++i) {
         invdir[i] = 1.0 / (ray.dir[i] + 1e-9);
@@ -347,6 +701,14 @@
         // Ray doesn't hit box
         return;
     } else {
+        // The backward recomputes much of the forward pass instead of storing it, to save memory.
+        //int num_sh[4] = {0, 3, 8, 15};
+        scalar_t prev_sigma = 0.;  // presumably piecewise-linear state as in the forward pass; confirm against its uses
+        scalar_t prev_color[3] = {0., 0., 0.};
+        scalar_t* prev_grad_leaf_val = nullptr;
+        // `= {}` value-initializes every entry to nullptr.
+        scalar_t* prev_grad_data_ptrs[MAX_TREE_DEPTH + 1] = {};
+        // int num_sh[4] = {15, 15, 15, 15};              
         scalar_t pos[3];
         scalar_t basis_fn[25];
         maybe_precalc_basis<scalar_t>(opt.format, opt.basis_dim, tree.extra_data,
@@ -356,369 +718,435 @@
         const scalar_t d_rgb_pad = 1 + 2 * opt.rgb_padding;
         // PASS 1
         {
-            scalar_t light_intensity = 1.f, t = tmin, cube_sz;
+            scalar_t light_intensity = 1.f, t = tmin;
             while (t < tmax) {
-                for (int j = 0; j < 3; ++j) pos[j] = ray.origin[j] + t * ray.dir[j];
+                for (int j = 0; j < 3; ++j) pos[j] = ray.origin[j] + t * ray.dir[j] + ray.random_offset[j];
 
-                const scalar_t* tree_val = query_single_from_root<scalar_t>(
-                        tree.data, tree.child, pos, &cube_sz);
-                // Reuse offset on gradient
-                const int64_t curr_leaf_offset = tree_val - tree.data.data();
-                scalar_t* grad_tree_val = grad_data_out.data() + curr_leaf_offset;
+                // Same as forward, but also get ptrs to where the gradients should be gathered to
+                int valid_nodes = 0;
 
-                scalar_t att;
+                scalar_t* data_ptrs[MAX_TREE_DEPTH + 1];
+                scalar_t* grad_data_ptrs[MAX_TREE_DEPTH + 1] = {};
+
+                int64_t node_ids[MAX_TREE_DEPTH + 1];
+                scalar_t all_cube_sz[MAX_TREE_DEPTH + 1];
+                // maximum number of coefficients is max depth of the tree * 3 basis for haar
+                scalar_t wavelet_evals[(MAX_TREE_DEPTH + 1) * MAX_WAVELET_SIZE];
+                
+                scalar_t all_relative_pos_data[(MAX_TREE_DEPTH + 1) * 3];
+                scalar_t *all_relative_pos = all_relative_pos_data;
+
+                if (accumulate){
+                    query_path_from_root<scalar_t>(tree.data,
+                                                   tree.child,
+                                                   pos, 
+                                                   all_relative_pos,
+                                                   all_cube_sz, 
+                                                   data_ptrs,
+                                                   tree_max_depth,
+                                                   &node_ids[0]);
+    
+                    while (valid_nodes < tree_max_depth + 1){
+                        if (node_ids[valid_nodes] == INVALID_NODE_ID){
+                            if (valid_nodes == 0){
+                                printf("ERROR, some node id is invalid as it was already -1 on the root!");  
+                            }
+                            break;
+                        }
+                        valid_nodes += 1;
+                    }
+                    if (valid_nodes == MAX_TREE_DEPTH + 1){
+                        printf("ERROR, reached out depth == MAX_TREE_DEPTH, there's a bug or your tree is too big!!!");
+                        return;
+                    }
+                }
+                else{
+                    // just use the leaf, and use fast querying
+                    data_ptrs[0] = query_single_from_root<scalar_t>(tree.data, 
+                        tree.child,
+                        pos, 
+                        all_cube_sz, 
+                        &node_ids[0]);
+                    
+                    valid_nodes = 1;
+                }
+                
+                //int64_t leaf_node_id = node_ids[valid_nodes -1];
+                scalar_t* leaf_val = data_ptrs[valid_nodes - 1];
+                scalar_t leaf_cube_sz = all_cube_sz[valid_nodes -1];
+
+                // populate pointers to gradient data structure, with same offsets as data
+                for (int i =0; i < valid_nodes; ++i){
+                    // Reuse offset on gradient data structure, which has same size as data
+                    const int64_t curr_leaf_offset = data_ptrs[i] - tree.data.data();
+                    grad_data_ptrs[i] = grad_data_out.data() + curr_leaf_offset;
+                }
+
+
+                scalar_t att;  // attenuation, a_i, 1-exp(-sigma * delta)
                 scalar_t subcube_tmin, subcube_tmax;
                 _dda_unit(pos, invdir, &subcube_tmin, &subcube_tmax);
 
-                const scalar_t t_subcube = (subcube_tmax - subcube_tmin) / cube_sz;
+                scalar_t t_subcube = (subcube_tmax - subcube_tmin) / leaf_cube_sz;
                 const scalar_t delta_t = t_subcube + opt.step_size;
-                scalar_t sigma = tree_val[data_dim - 1];
+
+                int n_wavelet_basis;
+                maybe_precalc_wavelet<scalar_t>(
+                    opt.wavelet_type, opt.lowpass_depth, opt.eval_wavelet_integral, all_relative_pos, all_cube_sz, ray.vdir, 
+                    valid_nodes, t_subcube, &n_wavelet_basis, wavelet_evals);
+                
+
+                const int color_dim = out_data_dim * n_wavelet_basis * opt.basis_dim;
+                // same as forward. TODO: abstract into an inline function
+                scalar_t sigma;
+                if (opt.accumulate_sigma){
+                    sigma = 0.f;
+                    if (opt.lowpass_depth < 0){
+                        printf("ERROR! Lowpass depth < 0!");
+                    }
+                    for (int i = opt.lowpass_depth; i < valid_nodes; i++){
+                        if (!wavelet_sigma){
+                            sigma += data_ptrs[i][data_dim - 1];
+                        } else {
+                            // accumulate over wavelets, same as color but without sh
+                            for (int j = 0; j < n_wavelet_basis; ++j ){
+                                sigma += data_ptrs[i][color_dim + j] * wavelet_evals[i * n_wavelet_basis + j];
+                            }
+                        }
+                    }
+                } else {
+                    if (!wavelet_sigma){
+                        sigma = leaf_val[data_dim - 1];
+                    } else {
+                        for (int j = 0; j < n_wavelet_basis; ++j ){
+                            // use the wavelet evaluation of the leaf, which is valid_nodes - 1.
+                            sigma += leaf_val[color_dim + j] * wavelet_evals[(valid_nodes - 1) * n_wavelet_basis + j];
+                        }
+                    }
+                }
+    
+                if (piecewise_linear){
+                    if (t != tmin) {
+                    scalar_t new_sigma = sigma;
+                    sigma = (sigma + prev_sigma) / 2.0;
+                    prev_sigma = new_sigma;
+                    }
+                    else prev_sigma = sigma;
+                }
                 if (opt.density_softplus) sigma = _SOFTPLUS_M1(sigma);
+                // else, it is implicitly a ReLU, as the original paper
                 if (sigma > 0.0) {
+                    // TODO(machc): for piecewise linear, use (sigma + prev_sigma) / 2                  
                     att = expf(-delta_t * sigma * delta_scale);
+                    // T_i a_i 
                     const scalar_t weight = light_intensity * (1.f - att);
 
                     scalar_t total_color = 0.f;
-                    if (opt.format != FORMAT_RGBA) {
-                        for (int t = 0; t < out_data_dim; ++ t) {
-                            int off = t * opt.basis_dim;
-                            scalar_t tmp = 0.0;
-                            for (int i = opt.min_comp; i <= opt.max_comp; ++i) {
-                                tmp += basis_fn[i] * tree_val[off + i];
+                    // For each color channel
+                    for (int l = 0; l < out_data_dim; ++l) {
+                        int off = l  * n_wavelet_basis * opt.basis_dim;
+                        scalar_t color = linear_color ? 0.5f : 0.0f; // To match plenoxels when linear, it should be biased to 0.5
+                        for (int i = 0; i < valid_nodes; i++){
+                            for (int j = 0; j < n_wavelet_basis; ++j ){
+                                // TODO(machc): use fewer degrees for lower resolutions                              
+                                // TODO: expose on the main renderer, not on the code
+                                for (int k = opt.min_comp; k <= opt.max_comp; ++k)
+                                // for (int k = opt.min_comp; k <= num_sh[min(i / 2, 3)]; ++k)
+                                {                                  
+                                    //tree_level[i]; data_ptrs C[l] x N_wavelets[j] x N_sh_basis[k]
+                                    color += data_ptrs[i][off + opt.basis_dim * j + k] * 
+                                           wavelet_evals[i * n_wavelet_basis + j] * 
+                                           basis_fn[k];
+                                }
                             }
-                            const scalar_t sigmoid = _SIGMOID(tmp);
-                            const scalar_t tmp2 = weight * sigmoid * (1.0 - sigmoid) *
-                                                 grad_output[t] * d_rgb_pad;
-                            for (int i = opt.min_comp; i <= opt.max_comp; ++i) {
-                                const scalar_t toadd = basis_fn[i] * tmp2;
-                                atomicAdd(&grad_tree_val[off + i],
-                                        toadd);
+                        }
+                        if (piecewise_linear){
+                            if (t != tmin) {
+                            scalar_t new_color = color;
+                            color = (prev_color[l] + color) / 2.0;
+                            prev_color[l] = new_color;
                             }
-                            total_color += (sigmoid * d_rgb_pad - opt.rgb_padding)
-                                            * grad_output[t];
+                            else prev_color[l] = color;
                         }
-                    } else {
-                        for (int j = 0; j < out_data_dim; ++j) {
-                            const scalar_t sigmoid = _SIGMOID(tree_val[j]);
-                            const scalar_t toadd = weight * sigmoid * (
-                                    1.f - sigmoid) * grad_output[j] * d_rgb_pad;
-                            atomicAdd(&grad_tree_val[j], toadd);
-                            total_color += (sigmoid * d_rgb_pad - opt.rgb_padding)
-                                            * grad_output[j];
+                        scalar_t sigmoid;
+                        scalar_t tmp2;
+                        if (linear_color){
+                            if (color > 0.f) { // ReLU
+                                sigmoid = color;
+                                tmp2 = weight * grad_output[l];                        
+                            } else {
+                                sigmoid = 0.f;
+                                // derivative is also 0 if we are on the left side of the ReLU
+                                tmp2 = 0.f;
+                            }
+                        } else{
+                           // grad_output is the backprop gradients wrt each output [rgb]
+                           // color is where we accumulate the colors, tmp2 is grad wrt color
+                            sigmoid = _SIGMOID(color);
+                            tmp2 = weight * sigmoid * (1.0 - sigmoid) * grad_output[l] * d_rgb_pad;
+                        }
+                        for (int i = 0; i < valid_nodes; ++i){
+                            for (int j = 0; j < n_wavelet_basis; ++j ){
+                              // TODO(machc): use fewer degrees for lower resolutions
+                                // for (int k = opt.min_comp; k <= num_sh[min(i / 2, 3)]; ++k) {
+                                for (int k = opt.min_comp; k <= opt.max_comp; ++k) {
+                                    // grad wrt SH coefficients
+                                    const scalar_t toadd_grad = wavelet_evals[i * n_wavelet_basis + j] *  basis_fn[k] * tmp2;
+                                    // Split grad between this and prev sample (except on first step or on levels not present in previous sample).
+                                    // TODO(machc): this is not exactly right when the previous sample is at a finer scale than the present, but it should be rare.
+                                    if (prev_grad_data_ptrs[i] != nullptr && piecewise_linear) {
+                                        atomicAdd(&grad_data_ptrs[i][off + opt.basis_dim * j + k], (absolute_values and (toadd_grad < 0) ? -1 : 1) * toadd_grad / 2.0);
+                                        atomicAdd(&prev_grad_data_ptrs[i][off + opt.basis_dim * j + k], (absolute_values and (toadd_grad < 0) ? -1 : 1) * toadd_grad / 2.0);
+                                    }
+                                    else atomicAdd(&grad_data_ptrs[i][off + opt.basis_dim * j + k], (absolute_values and (toadd_grad < 0) ? -1 : 1) * toadd_grad);
+                                }
+                            }
                         }
+                        // total_color sums over RGB channels and is used in
+                        // the second pass for sigma gradients.
+                        if (linear_color){
+                            total_color += sigmoid * grad_output[l];
+                        } else {
+                            total_color += (sigmoid * d_rgb_pad - opt.rgb_padding) * grad_output[l];
+                        }
+                        
                     }
+
                     light_intensity *= att;
+                    // To be used in second pass.
                     accum += weight * total_color;
                 }
+
+                for (int i = 0; i < MAX_TREE_DEPTH+1; ++i)
+                  if (i < valid_nodes) prev_grad_data_ptrs[i] = grad_data_ptrs[i];
+                  else prev_grad_data_ptrs[i] = nullptr;
+
                 t += delta_t;
             }
             scalar_t total_grad = 0.f;
             for (int j = 0; j < out_data_dim; ++j)
                 total_grad += grad_output[j];
+            // residual bg color
             accum += light_intensity * opt.background_brightness * total_grad;
         }
+
         // PASS 2
         {
             // scalar_t accum_lo = 0.0;
-            scalar_t light_intensity = 1.f, t = tmin, cube_sz;
+            scalar_t light_intensity = 1.f, t = tmin;
             while (t < tmax) {
                 for (int j = 0; j < 3; ++j) pos[j] = ray.origin[j] + t * ray.dir[j];
-                const scalar_t* tree_val = query_single_from_root<scalar_t>(tree.data,
-                        tree.child, pos, &cube_sz);
-                // Reuse offset on gradient
-                const int64_t curr_leaf_offset = tree_val - tree.data.data();
-                scalar_t* grad_tree_val = grad_data_out.data() + curr_leaf_offset;
-
-                scalar_t att;
-                scalar_t subcube_tmin, subcube_tmax;
-                _dda_unit(pos, invdir, &subcube_tmin, &subcube_tmax);
-
-                const scalar_t t_subcube = (subcube_tmax - subcube_tmin) / cube_sz;
-                const scalar_t delta_t = t_subcube + opt.step_size;
-                scalar_t sigma = tree_val[data_dim - 1];
-                const scalar_t raw_sigma = sigma;
-                if (opt.density_softplus) sigma = _SOFTPLUS_M1(sigma);
-                if (sigma > 0.0) {
-                    att = expf(-delta_t * sigma * delta_scale);
-                    const scalar_t weight = light_intensity * (1.f - att);
-
-                    scalar_t total_color = 0.f;
-                    if (opt.format != FORMAT_RGBA) {
-                        for (int t = 0; t < out_data_dim; ++ t) {
-                            int off = t * opt.basis_dim;
-                            scalar_t tmp = 0.0;
-                            for (int i = opt.min_comp; i <= opt.max_comp; ++i) {
-                                tmp += basis_fn[i] * tree_val[off + i];
+                
+                scalar_t* data_ptrs[MAX_TREE_DEPTH + 1];
+                scalar_t* grad_data_ptrs[MAX_TREE_DEPTH + 1];
+
+                int64_t node_ids[MAX_TREE_DEPTH + 1];
+                scalar_t all_cube_sz[MAX_TREE_DEPTH + 1];
+                scalar_t wavelet_evals[(MAX_TREE_DEPTH + 1) * MAX_WAVELET_SIZE];
+                scalar_t all_relative_pos[(MAX_TREE_DEPTH + 1) * 3];
+
+                int valid_nodes = 0;
+                if (accumulate){
+                    query_path_from_root<scalar_t>(tree.data,
+                                                   tree.child,
+                                                   pos, 
+                                                   all_relative_pos,
+                                                   all_cube_sz, 
+                                                   data_ptrs,
+                                                   tree_max_depth,
+                                                   &node_ids[0]);
+    
+                    while (valid_nodes < tree_max_depth + 1){
+                        if (node_ids[valid_nodes] == INVALID_NODE_ID){
+                            if (valid_nodes == 0){
+                                printf("ERROR, some node id is invalid as it was already -1 on the root!");  
                             }
-                            total_color += (_SIGMOID(tmp) * d_rgb_pad - opt.rgb_padding)
-                                            * grad_output[t];
-                        }
-                    } else {
-                        for (int j = 0; j < out_data_dim; ++j) {
-                            total_color += (_SIGMOID(tree_val[j]) * d_rgb_pad - opt.rgb_padding)
-                                            * grad_output[j];
+                            break;
                         }
+                        valid_nodes += 1;
+                    }
+                    if (valid_nodes == MAX_TREE_DEPTH + 1){
+                        printf("Error, reached out depth == MAX_TREE_DEPTH, there's a bug or your tree is too big!!!");
+                        return;
                     }
-                    light_intensity *= att;
-                    accum -= weight * total_color;
-                    atomicAdd(
-                            &grad_tree_val[data_dim - 1],
-                            delta_t * delta_scale * (
-                                total_color * light_intensity - accum)
-                                *  (opt.density_softplus ?
-                                    _SIGMOID(raw_sigma - 1)
-                                    : 1)
-                            );
                 }
-                t += delta_t;
-            }
-        }
-    }
-}  // trace_ray_backward
-
-template <typename scalar_t>
-__device__ __inline__ void trace_ray_se_grad_hess(
-    PackedTreeSpec<scalar_t>& __restrict__ tree,
-    SingleRaySpec<scalar_t> ray,
-    RenderOptions& __restrict__ opt,
-    torch::TensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, int32_t> color_ref,
-    torch::TensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, int32_t> color_out,
-    torch::PackedTensorAccessor64<scalar_t, 5, torch::RestrictPtrTraits>
-        grad_data_out,
-    torch::PackedTensorAccessor64<scalar_t, 5, torch::RestrictPtrTraits>
-        hessdiag_out) {
-    const scalar_t delta_scale = _get_delta_scale(tree.scaling, ray.dir);
-
-    scalar_t tmin, tmax;
-    scalar_t invdir[3];
-    const int tree_N = tree.child.size(1);
-    const int data_dim = tree.data.size(4);
-    const int out_data_dim = color_out.size(0);
-
-#pragma unroll
-    for (int i = 0; i < 3; ++i) {
-        invdir[i] = 1.0 / (ray.dir[i] + 1e-9);
-    }
-    _dda_unit(ray.origin, invdir, &tmin, &tmax);
-
-    if (tmax < 0 || tmin > tmax) {
-        // Ray doesn't hit box
-        for (int j = 0; j < out_data_dim; ++j) {
-            color_out[j] = opt.background_brightness;
-        }
-        return;
-    } else {
-        scalar_t pos[3];
-        scalar_t basis_fn[25];
-        maybe_precalc_basis<scalar_t>(opt.format, opt.basis_dim, tree.extra_data,
-                ray.vdir, basis_fn);
+                else{
+                    // just use the leaf, and use fast querying
+                    data_ptrs[0] = query_single_from_root<scalar_t>(tree.data, 
+                        tree.child,
+                        pos, 
+                        all_cube_sz, 
+                        &node_ids[0]);
+                    
+                    valid_nodes = 1;
+                }
+                
 
-        const scalar_t d_rgb_pad = 1 + 2 * opt.rgb_padding;
+                scalar_t* leaf_val = data_ptrs[valid_nodes - 1];
+                scalar_t leaf_cube_sz = all_cube_sz[valid_nodes -1];
 
-        // PASS 1 - compute residual (trace_ray_se_grad_hess)
-        {
-            scalar_t light_intensity = 1.f, t = tmin, cube_sz;
-            while (t < tmax) {
-                for (int j = 0; j < 3; ++j) {
-                    pos[j] = ray.origin[j] + t * ray.dir[j];
+                // populate pointers to gradient data structure, with same offsets as data
+                for (int i =0; i < valid_nodes; ++i){
+                    // Reuse offset on gradient data structure, which has same size as data
+                    const int64_t curr_leaf_offset = data_ptrs[i] - tree.data.data();
+                    grad_data_ptrs[i] = grad_data_out.data() + curr_leaf_offset;
                 }
-
-                scalar_t* tree_val = query_single_from_root<scalar_t>(tree.data, tree.child,
-                        pos, &cube_sz, nullptr);
+                scalar_t* grad_leaf_val = grad_data_ptrs[valid_nodes - 1];
 
                 scalar_t att;
                 scalar_t subcube_tmin, subcube_tmax;
                 _dda_unit(pos, invdir, &subcube_tmin, &subcube_tmax);
 
-                const scalar_t t_subcube = (subcube_tmax - subcube_tmin) / cube_sz;
+                scalar_t t_subcube = (subcube_tmax - subcube_tmin) / leaf_cube_sz;
                 const scalar_t delta_t = t_subcube + opt.step_size;
-                scalar_t sigma = tree_val[data_dim - 1];
-                if (opt.density_softplus) sigma = _SOFTPLUS_M1(sigma);
-                if (sigma > 0.0f) {
-                    att = expf(-delta_t * delta_scale * sigma);
-                    const scalar_t weight = light_intensity * (1.f - att);
 
-                    if (opt.format != FORMAT_RGBA) {
-                        for (int t = 0; t < out_data_dim; ++ t) {
-                            int off = t * opt.basis_dim;
-                            scalar_t tmp = 0.0;
-                            for (int i = opt.min_comp; i <= opt.max_comp; ++i) {
-                                tmp += basis_fn[i] * tree_val[off + i];
+                int n_wavelet_basis;
+                maybe_precalc_wavelet<scalar_t>(
+                    opt.wavelet_type, opt.lowpass_depth, opt.eval_wavelet_integral, all_relative_pos, all_cube_sz, ray.vdir, 
+                    valid_nodes, t_subcube, &n_wavelet_basis, wavelet_evals);
+    
+                const int color_dim = out_data_dim * n_wavelet_basis * opt.basis_dim;             
+                scalar_t sigma;
+                if (opt.accumulate_sigma){
+                    sigma = 0.f;
+                    if (opt.lowpass_depth < 0){
+                        printf("ERROR! Lowpass depth < 0!");
+                    }
+                    for (int i = opt.lowpass_depth; i < valid_nodes; i++){
+                        if (!wavelet_sigma){
+                            sigma += data_ptrs[i][data_dim - 1];
+                        } else {
+                            // accumulate over wavelets, same as color but without sh
+                            for (int j = 0; j < n_wavelet_basis; ++j ){
+                                sigma += data_ptrs[i][color_dim + j] * wavelet_evals[i * n_wavelet_basis + j];
                             }
-                            color_out[t] += weight * (_SIGMOID(tmp) * d_rgb_pad - opt.rgb_padding);
                         }
+                    }
+                } else {
+                    if (!wavelet_sigma){
+                        sigma = leaf_val[data_dim - 1];
                     } else {
-                        for (int j = 0; j < out_data_dim; ++j) {
-                            color_out[j] += weight * (_SIGMOID(tree_val[j]) *
-                                    d_rgb_pad - opt.rgb_padding);
+                        for (int j = 0; j < n_wavelet_basis; ++j ){
+                            // use the wavelet evaluation of the leaf, which is valid_nodes - 1.
+                            sigma += leaf_val[color_dim + j] * wavelet_evals[(valid_nodes - 1) * n_wavelet_basis + j];
                         }
                     }
-                    light_intensity *= att;
                 }
-                t += delta_t;
-            }
-            // Add background intensity & color -> residual
-            for (int j = 0; j < out_data_dim; ++j) {
-                color_out[j] += light_intensity * opt.background_brightness - color_ref[j];
-            }
-        }
 
-        // PASS 2 - compute RGB gradient & suffix (trace_ray_se_grad_hess)
-        scalar_t color_accum[4] = {0, 0, 0, 0};
-        {
-            scalar_t light_intensity = 1.f, t = tmin, cube_sz;
-            while (t < tmax) {
-                for (int j = 0; j < 3; ++j) pos[j] = ray.origin[j] + t * ray.dir[j];
-
-                const scalar_t* tree_val = query_single_from_root<scalar_t>(
-                        tree.data, tree.child, pos, &cube_sz);
-                // Reuse offset on gradient
-                const int64_t curr_leaf_offset = tree_val - tree.data.data();
-                scalar_t* grad_tree_val = grad_data_out.data() + curr_leaf_offset;
-                scalar_t* hessdiag_tree_val = hessdiag_out.data() + curr_leaf_offset;
-
-                scalar_t att;
-                scalar_t subcube_tmin, subcube_tmax;
-                _dda_unit(pos, invdir, &subcube_tmin, &subcube_tmax);
-
-                const scalar_t t_subcube = (subcube_tmax - subcube_tmin) / cube_sz;
-                const scalar_t delta_t = t_subcube + opt.step_size;
-                scalar_t sigma = tree_val[data_dim - 1];
+                // For piecewise linear, use (sigma + prev_sigma) / 2
+                if (piecewise_linear){
+                    if (t != tmin) {
+                        scalar_t new_sigma = sigma;
+                        sigma = (sigma + prev_sigma) / 2.0;
+                        prev_sigma = new_sigma;
+                    }
+                    else prev_sigma = sigma;
+                }
+                const scalar_t raw_sigma = sigma;
                 if (opt.density_softplus) sigma = _SOFTPLUS_M1(sigma);
                 if (sigma > 0.0) {
                     att = expf(-delta_t * sigma * delta_scale);
                     const scalar_t weight = light_intensity * (1.f - att);
 
-                    if (opt.format != FORMAT_RGBA) {
-                        for (int t = 0; t < out_data_dim; ++ t) {
-                            int off = t * opt.basis_dim;
-                            scalar_t tmp = 0.0;
-                            for (int i = opt.min_comp; i <= opt.max_comp; ++i) {
-                                tmp += basis_fn[i] * tree_val[off + i];
+                    scalar_t total_color = 0.f;
+                    for (int l = 0; l < out_data_dim; ++l) {
+                        int off = l * n_wavelet_basis * opt.basis_dim;
+                        scalar_t color = linear_color ? 0.5f : 0.0f; // To match plenoxels when linear, it should be biased to 0.5
+                        for (int i = 0; i < valid_nodes; i++){
+                            for (int j = 0; j < n_wavelet_basis; ++j ){
+                                // for (int k = opt.min_comp; k <= num_sh[min(i / 2, 3)]; ++k) {
+                                for (int k = opt.min_comp; k <= opt.max_comp; ++k) {
+                                    color += data_ptrs[i][off + opt.basis_dim * j + k] * 
+                                           wavelet_evals[i * n_wavelet_basis + j] * 
+                                           basis_fn[k];
+                                }
                             }
-                            const scalar_t sigmoid = _SIGMOID(tmp);
-                            const scalar_t grad_ci = weight * sigmoid * (1.0 - sigmoid) *
-                                                  d_rgb_pad;
-                            // const scalar_t d2_term =
-                            //     (1.f - 2.f * sigmoid) * color_out[t];
-                            for (int i = opt.min_comp; i <= opt.max_comp; ++i) {
-                                const scalar_t grad_wi = basis_fn[i] * grad_ci;
-                                atomicAdd(&grad_tree_val[off + i], grad_wi * color_out[t]);
-                                atomicAdd(&hessdiag_tree_val[off + i],
-                                        // grad_wi * basis_fn[i] * (grad_ci +
-                                        //         d2_term)                   // Newton
-                                        grad_wi * grad_wi                     // Gauss-Newton
-                                    );
+                        }
+                        if (piecewise_linear){
+                            if (t != tmin) {
+                            scalar_t new_color = color;
+                            color = (prev_color[l] + color) / 2.0;
+                            prev_color[l] = new_color;
                             }
-                            const scalar_t color_j = sigmoid * d_rgb_pad - opt.rgb_padding;
-                            color_accum[t] += weight * color_j;
+                            else prev_color[l] = color;
                         }
-                    } else {
-                        for (int j = 0; j < out_data_dim; ++j) {
-                            const scalar_t sigmoid = _SIGMOID(tree_val[j]);
-                            const scalar_t grad_ci = weight * sigmoid * (
-                                    1.f - sigmoid) * d_rgb_pad;
-                            // const scalar_t d2_term = (1.f - 2.f * sigmoid) * color_out[j];
-                            atomicAdd(&grad_tree_val[j], grad_ci * color_out[j]);
-                            // Newton
-                            // atomicAdd(&hessdiag_tree_val[j], grad_ci * (grad_ci + d2_term));
-                            // Gauss-Newton
-                            atomicAdd(&hessdiag_tree_val[j], grad_ci * grad_ci);
-                            const scalar_t color_j = sigmoid * d_rgb_pad - opt.rgb_padding;
-                            color_accum[j] += weight * color_j;
+                        if (linear_color){
+                            if (color > 0.f) { // ReLU
+                                total_color += color * grad_output[l];
+                            }
+                        } else{
+                            total_color += (_SIGMOID(color) * d_rgb_pad - opt.rgb_padding) * grad_output[l];
                         }
                     }
+
                     light_intensity *= att;
-                }
-                t += delta_t;
-            }
-            for (int j = 0; j < out_data_dim; ++j) {
-                color_accum[j] += light_intensity * opt.background_brightness;
-            }
-        }
+                    accum -= weight * total_color;
+                    scalar_t current_grad_val = delta_t * delta_scale * (total_color * light_intensity - accum) *  (opt.density_softplus ? _SIGMOID(raw_sigma - 1) : 1);
+                    
+                    if (sigma_penalty > 0){
+                        // 0 loss to the output, which . we ignore the fact that it goes through the softplus, so it is an approximation if it is with softplus
+                        // this iteration only renders through leaf_cube_sz for all voxels, so the case of accumulate or not is the same
+                        // bigger cubes will receive more hits, so we penalize them more in practice
+                        current_grad_val += 1 * sigma_penalty / leaf_cube_sz / leaf_cube_sz / leaf_cube_sz;
+                    }
 
-        // PASS 3 - finish computing sigma gradient (trace_ray_se_grad_hess)
-        {
-            scalar_t light_intensity = 1.f, t = tmin, cube_sz;
-            scalar_t color_curr[4];
-            while (t < tmax) {
-                for (int j = 0; j < 3; ++j) pos[j] = ray.origin[j] + t * ray.dir[j];
-                const scalar_t* tree_val = query_single_from_root<scalar_t>(tree.data,
-                        tree.child, pos, &cube_sz);
-                // Reuse offset on gradient
-                const int64_t curr_leaf_offset = tree_val - tree.data.data();
-                scalar_t* grad_tree_val = grad_data_out.data() + curr_leaf_offset;
-                scalar_t* hessdiag_tree_val = hessdiag_out.data() + curr_leaf_offset;
+                    // When smoothing out sigma in the integration,
+                    // also update the previous sigma gradient.
 
-                scalar_t att;
-                scalar_t subcube_tmin, subcube_tmax;
-                _dda_unit(pos, invdir, &subcube_tmin, &subcube_tmax);
+                    // TODO(machc): only the noaccumulate and no
+                    // wavelet sigma are implemented!
 
-                const scalar_t t_subcube = (subcube_tmax - subcube_tmin) / cube_sz;
-                const scalar_t delta_t = t_subcube + opt.step_size;
-                scalar_t sigma = tree_val[data_dim - 1];
-                const scalar_t raw_sigma = sigma;
-                if (opt.density_softplus) sigma = _SOFTPLUS_M1(sigma);
-                if (sigma > 0.0) {
-                    att = expf(-delta_t * sigma * delta_scale);
-                    const scalar_t weight = light_intensity * (1.f - att);
+                    // grad is now split between sigma and prev_sigma.
 
-                    if (opt.format != FORMAT_RGBA) {
-                        for (int u = 0; u < out_data_dim; ++ u) {
-                            int off = u * opt.basis_dim;
-                            scalar_t tmp = 0.0;
-                            for (int i = opt.min_comp; i <= opt.max_comp; ++i) {
-                                tmp += basis_fn[i] * tree_val[off + i];
-                            }
-                            color_curr[u] = _SIGMOID(tmp) * d_rgb_pad - opt.rgb_padding;
-                            color_accum[u] -= weight * color_curr[u];
+                    if (opt.accumulate_sigma){
+                        // TODO: adapt for sigma wavelet
+                        if (opt.lowpass_depth < 0){
+                            printf("ERROR! Lowpass depth < 0!");
                         }
-                    } else {
-                        for (int j = 0; j < out_data_dim; ++j) {
-                            color_curr[j] = _SIGMOID(tree_val[j]) * d_rgb_pad - opt.rgb_padding;
-                            color_accum[j] -= weight * color_curr[j];
+                        for (int i = opt.lowpass_depth; i < valid_nodes; i++){
+                            // The sigma is just a sum of terms sigma_i terms, so we backprop the same quantity to all of them
+                            if (!wavelet_sigma){
+                                atomicAdd(&grad_data_ptrs[i][data_dim - 1], (absolute_values and (current_grad_val < 0) ? -1 : 1) * current_grad_val);
+                            } else {
+                                // add gradients multiplied by wavelet evaluation, at multiple scales
+                                for (int j = 0; j < n_wavelet_basis; ++j ){
+                                    const scalar_t toadd_grad = current_grad_val *  wavelet_evals[i * n_wavelet_basis + j];
+                                    atomicAdd(&grad_data_ptrs[i][color_dim + j], (absolute_values and (toadd_grad < 0) ? -1 : 1) * toadd_grad);
+                                }
+                            }
                         }
-                    }
-                    light_intensity *= att;
-                    for (int j = 0; j < out_data_dim; ++j) {
-                        const scalar_t grad_sigma = delta_t * delta_scale * (
-                                color_curr[j] * light_intensity - color_accum[j]);
-                        // Newton
-                        // const scalar_t grad2_sigma =
-                        //     grad_sigma * (grad_sigma - delta_t * delta_scale * color_out[j]);
-                        // Gauss-Newton
-                        const scalar_t grad2_sigma = grad_sigma * grad_sigma;
-                        if (opt.density_softplus) {
-                            const scalar_t sigmoid = _SIGMOID(raw_sigma - 1);
-                            const scalar_t d_sigmoid = sigmoid * (1.f - sigmoid);
-                            // FIXME not sure this works
-                            atomicAdd(&grad_tree_val[data_dim - 1], grad_sigma *
-                                    color_out[j] * sigmoid);
-                            atomicAdd(&hessdiag_tree_val[data_dim - 1],
-                                    grad2_sigma * sigmoid * sigmoid
-                                    + grad_sigma *  d_sigmoid);
+                    } else {  // !opt.accumulate_sigma
+                        // add to sigma
+                        if (!wavelet_sigma){
+                            // This also updates previous grad!
+                            if (prev_grad_leaf_val != nullptr && piecewise_linear) {
+                              // Split by two the gradient.
+                              atomicAdd(&grad_leaf_val[data_dim - 1], (absolute_values and (current_grad_val < 0) ? -1 : 1) * current_grad_val / 2.0);
+                              atomicAdd(&prev_grad_leaf_val[data_dim - 1], (absolute_values and (current_grad_val < 0) ? -1 : 1) * current_grad_val / 2.0);
+                            }
+                            else atomicAdd(&grad_leaf_val[data_dim - 1], (absolute_values and (current_grad_val < 0) ? -1 : 1) * current_grad_val);
                         } else {
-                            atomicAdd(&grad_tree_val[data_dim - 1],
-                                    grad_sigma * color_out[j]);
-                            atomicAdd(&hessdiag_tree_val[data_dim - 1], grad2_sigma);
+                            for (int j = 0; j < n_wavelet_basis; ++j ){
+                                const scalar_t toadd_grad = current_grad_val * wavelet_evals[(valid_nodes - 1) * n_wavelet_basis + j];
+                                atomicAdd(&grad_leaf_val[color_dim + j], (absolute_values and (toadd_grad < 0) ? -1 : 1) * toadd_grad);
+                            }
                         }
                     }
                 }
                 t += delta_t;
+                if (piecewise_linear){
+                    prev_grad_leaf_val = grad_leaf_val;
+                }
             }
         }
-        // Residual -> color
-        for (int j = 0; j < out_data_dim; ++j) {
-            color_out[j] += color_ref[j];
-        }
     }
-}
+}  // trace_ray_backward
 
 template <typename scalar_t>
 __global__ void render_ray_kernel(
         PackedTreeSpec<scalar_t> tree,
-        PackedRaysSpec<scalar_t> rays,
-        RenderOptions opt,
+        PackedRaysSpecSwvox<scalar_t> rays,
+        RenderOptionsSwvox opt,
     torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits>
         out) {
     CUDA_GET_THREAD_ID(tid, rays.origins.size(0));
@@ -727,7 +1155,7 @@
     scalar_t dir[3] = {rays.dirs[tid][0], rays.dirs[tid][1], rays.dirs[tid][2]};
     trace_ray<scalar_t>(
         tree,
-        SingleRaySpec<scalar_t>{origin, dir, &rays.vdirs[tid][0]},
+        SingleRaySpec2<scalar_t>{origin, dir, &rays.vdirs[tid][0], &rays.random_offset[tid][0]},
         opt,
         out[tid]);
 }
@@ -738,11 +1166,9 @@
     PackedTreeSpec<scalar_t> tree,
     const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits>
         grad_output,
-        PackedRaysSpec<scalar_t> rays,
-        RenderOptions opt,
-    torch::PackedTensorAccessor64<scalar_t, 5, torch::RestrictPtrTraits>
-        grad_data_out
-        ) {
+        PackedRaysSpecSwvox<scalar_t> rays,
+        RenderOptionsSwvox opt,
+    torch::PackedTensorAccessor64<scalar_t, 5, torch::RestrictPtrTraits> grad_data_out) {
     CUDA_GET_THREAD_ID(tid, rays.origins.size(0));
     scalar_t origin[3] = {rays.origins[tid][0], rays.origins[tid][1], rays.origins[tid][2]};
     transform_coord<scalar_t>(origin, tree.offset, tree.scaling);
@@ -750,7 +1176,7 @@
     trace_ray_backward<scalar_t>(
         tree,
         grad_output[tid],
-        SingleRaySpec<scalar_t>{origin, dir, &rays.vdirs[tid][0]},
+        SingleRaySpec2<scalar_t>{origin, dir, &rays.vdirs[tid][0], &rays.random_offset[tid][0]},
         opt,
         grad_data_out);
 }
@@ -760,7 +1186,7 @@
     int ix, int iy,
     scalar_t* dir,
     scalar_t* origin,
-    const PackedCameraSpec<scalar_t>& __restrict__ cam) {
+    const PackedCameraSpecSwvox<scalar_t>& __restrict__ cam) {
     scalar_t x = (ix - 0.5 * cam.width) / cam.fx;
     scalar_t y = -(iy - 0.5 * cam.height) / cam.fy;
     scalar_t z = sqrtf(x * x + y * y + 1.0);
@@ -774,7 +1200,7 @@
 
 template <typename scalar_t>
 __host__ __device__ __inline__ static void maybe_world2ndc(
-        RenderOptions& __restrict__ opt,
+        RenderOptionsSwvox& __restrict__ opt,
         scalar_t* __restrict__ dir,
         scalar_t* __restrict__ cen, scalar_t near = 1.f) {
     if (opt.ndc_width < 0)
@@ -799,13 +1225,14 @@
 template <typename scalar_t>
 __global__ void render_image_kernel(
     PackedTreeSpec<scalar_t> tree,
-    PackedCameraSpec<scalar_t> cam,
-    RenderOptions opt,
+    PackedCameraSpecSwvox<scalar_t> cam,
+    RenderOptionsSwvox opt,
     torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits>
         out) {
     CUDA_GET_THREAD_ID(tid, cam.width * cam.height);
     int iy = tid / cam.width, ix = tid % cam.width;
     scalar_t dir[3], origin[3];
+    scalar_t random_offsets[3] = {0.0, 0.0, 0.0};
     cam2world_ray(ix, iy, dir, origin, cam);
     scalar_t vdir[3] = {dir[0], dir[1], dir[2]};
     maybe_world2ndc(opt, dir, origin);
@@ -813,7 +1240,7 @@
     transform_coord<scalar_t>(origin, tree.offset, tree.scaling);
     trace_ray<scalar_t>(
         tree,
-        SingleRaySpec<scalar_t>{origin, dir, vdir},
+        SingleRaySpec2<scalar_t>{origin, dir, vdir, random_offsets},
         opt,
         out[iy][ix]);
 }
@@ -823,8 +1250,8 @@
     PackedTreeSpec<scalar_t> tree,
     const torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits>
         grad_output,
-    PackedCameraSpec<scalar_t> cam,
-    RenderOptions opt,
+    PackedCameraSpecSwvox<scalar_t> cam,
+    RenderOptionsSwvox opt,
     torch::PackedTensorAccessor64<scalar_t, 5, torch::RestrictPtrTraits>
         grad_data_out) {
     CUDA_GET_THREAD_ID(tid, cam.width * cam.height);
@@ -832,191 +1259,56 @@
     scalar_t dir[3], origin[3];
     cam2world_ray(ix, iy, dir, origin, cam);
     scalar_t vdir[3] = {dir[0], dir[1], dir[2]};
+    scalar_t random_offset[3] = {0.0, 0.0, 0.0};
     maybe_world2ndc(opt, dir, origin);
 
     transform_coord<scalar_t>(origin, tree.offset, tree.scaling);
     trace_ray_backward<scalar_t>(
         tree,
         grad_output[iy][ix],
-        SingleRaySpec<scalar_t>{origin, dir, vdir},
+        SingleRaySpec2<scalar_t>{origin, dir, vdir, random_offset},
         opt,
         grad_data_out);
 }
 
-template <typename scalar_t>
-__global__ void se_grad_kernel(
-    PackedTreeSpec<scalar_t> tree,
-    PackedRaysSpec<scalar_t> rays,
-    RenderOptions opt,
-    torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> color_ref,
-    torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> color_out,
-    torch::PackedTensorAccessor64<scalar_t, 5, torch::RestrictPtrTraits> grad_out,
-    torch::PackedTensorAccessor64<scalar_t, 5, torch::RestrictPtrTraits> hessdiag_out) {
-    CUDA_GET_THREAD_ID(tid, rays.origins.size(0));
-    scalar_t origin[3] = {rays.origins[tid][0], rays.origins[tid][1], rays.origins[tid][2]};
-    transform_coord<scalar_t>(origin, tree.offset, tree.scaling);
-    scalar_t dir[3] = {rays.dirs[tid][0], rays.dirs[tid][1], rays.dirs[tid][2]};
-
-    trace_ray_se_grad_hess<scalar_t>(
-        tree,
-        SingleRaySpec<scalar_t>{origin, dir, &rays.vdirs[tid][0]},
-        opt,
-        color_ref[tid],
-        color_out[tid],
-        grad_out,
-        hessdiag_out);
-}
-
-template <typename scalar_t>
-__global__ void se_grad_persp_kernel(
-    PackedTreeSpec<scalar_t> tree,
-    PackedCameraSpec<scalar_t> cam,
-    RenderOptions opt,
-    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits>
-        color_ref,
-    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits>
-        color_out,
-    torch::PackedTensorAccessor64<scalar_t, 5, torch::RestrictPtrTraits> grad_out,
-    torch::PackedTensorAccessor64<scalar_t, 5, torch::RestrictPtrTraits> hessdiag_out) {
-    CUDA_GET_THREAD_ID(tid, cam.width * cam.height);
-    int iy = tid / cam.width, ix = tid % cam.width;
-    scalar_t dir[3], origin[3];
-    cam2world_ray(ix, iy, dir, origin, cam);
-    scalar_t vdir[3] = {dir[0], dir[1], dir[2]};
-    maybe_world2ndc(opt, dir, origin);
-
-    transform_coord<scalar_t>(origin, tree.offset, tree.scaling);
-    trace_ray_se_grad_hess<scalar_t>(
-        tree,
-        SingleRaySpec<scalar_t>{origin, dir, vdir},
-        opt,
-        color_ref[iy][ix],
-        color_out[iy][ix],
-        grad_out,
-        hessdiag_out);
-}
-
-template <typename scalar_t>
-__device__ __inline__ void grid_trace_ray(
-    const torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits>
-        data,
-        const scalar_t* __restrict__ origin,
-        const scalar_t* __restrict__ dir,
-        const scalar_t* __restrict__ vdir,
-        scalar_t step_size,
-        scalar_t delta_scale,
-        scalar_t sigma_thresh,
-    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits>
-        grid_weight,
-    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits>
-        grid_hit) {
-    scalar_t tmin, tmax;
-    scalar_t invdir[3];
-    const int reso = data.size(0);
-    scalar_t* grid_weight_val = grid_weight.data();
-    scalar_t* grid_hit_val = grid_hit.data();
+}  // namespace device
 
-#pragma unroll
-    for (int i = 0; i < 3; ++i) {
-        invdir[i] = 1.0 / (dir[i] + 1e-9);
+// Compute output color dim from per-leaf data size, basis dim & wavelet layout
+__host__ int get_out_data_dim(int format, int basis_dim, int in_data_dim, int wavelet_type, bool wavelet_sigma) {
+    int wavelet_dims;  // coefficients stored per color/basis entry for this wavelet_type
+    if (wavelet_type == CONSTANT){
+        wavelet_dims = 1;
+    } else if (wavelet_type == HAAR){
+        wavelet_dims = 7;
+    } else if (wavelet_type == TRILINEAR){
+        wavelet_dims = 8;
+    } else if (wavelet_type == SIDE){
+        wavelet_dims = 2;
+    } else if (wavelet_type == DB2){
+        wavelet_dims = 8;
+    } else if (wavelet_type == BIOR22){
+        wavelet_dims = 8;
     }
-    _dda_unit(origin, invdir, &tmin, &tmax);

-    if (tmax < 0 || tmin > tmax) {
-        // Ray doesn't hit box
-        return;
+    else {  // unknown type: fail loudly instead of dividing by an uninitialized wavelet_dims
+        throw std::runtime_error("get_out_data_dim: wavelet_type not implemented");
+    }
+    int n_sigma_basis;  // sigma channels: wavelet_dims if sigma is wavelet-coded, else 1
+    if (wavelet_sigma){
+        n_sigma_basis = wavelet_dims;
     } else {
-        scalar_t pos[3];
-
-        scalar_t light_intensity = 1.f;
-        scalar_t t = tmin;
-        scalar_t cube_sz = reso;
-        int32_t u, v, w, node_id;
-        while (t < tmax) {
-            for (int j = 0; j < 3; ++j) {
-                pos[j] = origin[j] + t * dir[j];
-            }
-
-            clamp_coord<scalar_t>(pos);
-            pos[0] *= reso;
-            pos[1] *= reso;
-            pos[2] *= reso;
-            u = floor(pos[0]);
-            v = floor(pos[1]);
-            w = floor(pos[2]);
-            pos[0] -= u;
-            pos[1] -= v;
-            pos[2] -= w;
-            node_id = u * reso * reso + v * reso + w;
-
-            scalar_t att;
-            scalar_t subcube_tmin, subcube_tmax;
-            _dda_unit(pos, invdir, &subcube_tmin, &subcube_tmax);
-
-            const scalar_t t_subcube = (subcube_tmax - subcube_tmin) / cube_sz;
-            const scalar_t delta_t = t_subcube + step_size;
-            scalar_t sigma = data[u][v][w];
-            if (sigma > sigma_thresh) {
-                att = expf(-delta_t * delta_scale * sigma);
-                const scalar_t weight = light_intensity * (1.f - att);
-                light_intensity *= att;
-
-                atomicMax(&grid_weight_val[node_id], weight);
-                atomicAdd(&grid_hit_val[node_id], (scalar_t) 1.0);
-            }
-            t += delta_t;
-        }
+        n_sigma_basis = 1;
     }
-}
-
-template <typename scalar_t>
-__global__ void grid_weight_render_kernel(
-    const torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits>
-        data,
-    PackedCameraSpec<scalar_t> cam,
-    RenderOptions opt,
-    const scalar_t* __restrict__ offset,
-    const scalar_t* __restrict__ scaling,
-    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits>
-        grid_weight,
-    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits>
-        grid_hit) {
-    CUDA_GET_THREAD_ID(tid, cam.width * cam.height);
-    int iy = tid / cam.width, ix = tid % cam.width;
-    scalar_t dir[3], origin[3];
-    cam2world_ray(ix, iy, dir, origin, cam);
-    scalar_t vdir[3] = {dir[0], dir[1], dir[2]};
-    maybe_world2ndc(opt, dir, origin);
-
-    transform_coord<scalar_t>(origin, offset, scaling);
-    const scalar_t delta_scale = _get_delta_scale(scaling, dir);
-    grid_trace_ray<scalar_t>(
-        data,
-        origin,
-        dir,
-        vdir,
-        opt.step_size,
-        delta_scale,
-        opt.sigma_thresh,
-        grid_weight,
-        grid_hit);
-}
-
-}  // namespace device
-
-
-// Compute RGB output dimension from input dimension & SH degree
-__host__ int get_out_data_dim(int format, int basis_dim, int in_data_dim) {
     if (format != FORMAT_RGBA) {
-        return (in_data_dim - 1) / basis_dim;
+        return (in_data_dim - n_sigma_basis) / basis_dim / wavelet_dims;  // drop sigma, then per basis/wavelet coeff
     } else {
-        return in_data_dim - 1;
+        return (in_data_dim - n_sigma_basis) / wavelet_dims;  // RGBA: one coeff set per channel, no basis
     }
 }
 
 }  // namespace
 
-torch::Tensor volume_render(TreeSpec& tree, RaysSpec& rays, RenderOptions& opt) {
+torch::Tensor volume_render(TreeSpec& tree, RaysSpecSwvox& rays, RenderOptionsSwvox& opt) {
     tree.check();
     rays.check();
     DEVICE_GUARD(tree.data);
@@ -1024,8 +1316,10 @@

     auto_cuda_threads();
     const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
-    int out_data_dim = get_out_data_dim(opt.format, opt.basis_dim, tree.data.size(4));
+    int out_data_dim = get_out_data_dim(opt.format, opt.basis_dim, tree.data.size(4), opt.wavelet_type, opt.wavelet_sigma);
     torch::Tensor result = torch::zeros({Q, out_data_dim}, rays.origins.options());
+    // NOTE: AT_DISPATCH_FLOATING_TYPES dispatches float/double only; half tensors are
+    // rejected here (half would need AT_DISPATCH_FLOATING_TYPES_AND_HALF + kernel support).
     AT_DISPATCH_FLOATING_TYPES(rays.origins.type(), __FUNCTION__, [&] {
             device::render_ray_kernel<scalar_t><<<blocks, cuda_n_threads>>>(
                     tree, rays, opt,
@@ -1035,30 +1329,10 @@
     return result;
 }
 
-torch::Tensor volume_render_image(TreeSpec& tree, CameraSpec& cam, RenderOptions& opt) {
-    tree.check();
-    cam.check();
-    DEVICE_GUARD(tree.data);
-    const size_t Q = size_t(cam.width) * cam.height;
-
-    auto_cuda_threads();
-    const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
-    int out_data_dim = get_out_data_dim(opt.format, opt.basis_dim, tree.data.size(4));
-    torch::Tensor result = torch::zeros({cam.height, cam.width, out_data_dim},
-            tree.data.options());
-
-    AT_DISPATCH_FLOATING_TYPES(tree.data.type(), __FUNCTION__, [&] {
-            device::render_image_kernel<scalar_t><<<blocks, cuda_n_threads>>>(
-                    tree, cam, opt,
-                    result.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>());
-    });
-    CUDA_CHECK_ERRORS;
-    return result;
-}
 
 torch::Tensor volume_render_backward(
-    TreeSpec& tree, RaysSpec& rays,
-    RenderOptions& opt,
+    TreeSpec& tree, RaysSpecSwvox& rays,
+    RenderOptionsSwvox& opt,
     torch::Tensor grad_output) {
     tree.check();
     rays.check();
@@ -1068,7 +1342,7 @@
 
     auto_cuda_threads();
     const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
-    int out_data_dim = get_out_data_dim(opt.format, opt.basis_dim, tree.data.size(4));
+    int out_data_dim = get_out_data_dim(opt.format, opt.basis_dim, tree.data.size(4), opt.wavelet_type, opt.wavelet_sigma);
     torch::Tensor result = torch::zeros_like(tree.data);
     AT_DISPATCH_FLOATING_TYPES(rays.origins.type(), __FUNCTION__, [&] {
             device::render_ray_backward_kernel<scalar_t><<<blocks, cuda_n_threads>>>(
@@ -1081,118 +1355,3 @@
     CUDA_CHECK_ERRORS;
     return result;
 }
-
-torch::Tensor volume_render_image_backward(TreeSpec& tree, CameraSpec& cam,
-                                           RenderOptions& opt,
-                                           torch::Tensor grad_output) {
-    tree.check();
-    cam.check();
-    DEVICE_GUARD(tree.data);
-
-    const size_t Q = size_t(cam.width) * cam.height;
-
-    auto_cuda_threads();
-    const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
-    int out_data_dim = get_out_data_dim(opt.format, opt.basis_dim, tree.data.size(4));
-    torch::Tensor result = torch::zeros_like(tree.data);
-
-    AT_DISPATCH_FLOATING_TYPES(tree.data.type(), __FUNCTION__, [&] {
-            device::render_image_backward_kernel<scalar_t><<<blocks, cuda_n_threads>>>(
-                tree,
-                grad_output.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
-                cam,
-                opt,
-                result.packed_accessor64<scalar_t, 5, torch::RestrictPtrTraits>());
-    });
-    CUDA_CHECK_ERRORS;
-    return result;
-}
-
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> se_grad(
-        TreeSpec& tree, RaysSpec& rays, torch::Tensor color, RenderOptions& opt) {
-    tree.check();
-    rays.check();
-    DEVICE_GUARD(tree.data);
-    CHECK_INPUT(color);
-
-    const auto Q = rays.origins.size(0);
-
-    auto_cuda_threads();
-    const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
-    int out_data_dim = get_out_data_dim(opt.format, opt.basis_dim, tree.data.size(4));
-    if (out_data_dim > 4) {
-        throw std::runtime_error("Tree's output dim cannot be > 4 for se_grad");
-    }
-    torch::Tensor result = torch::zeros({Q, out_data_dim}, rays.origins.options());
-    torch::Tensor grad = torch::zeros_like(tree.data);
-    torch::Tensor hessdiag = torch::zeros_like(tree.data);
-    AT_DISPATCH_FLOATING_TYPES(rays.origins.type(), __FUNCTION__, [&] {
-            device::se_grad_kernel<scalar_t><<<blocks, cuda_n_threads>>>(
-                    tree, rays, opt,
-                    color.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
-                    result.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
-                    grad.packed_accessor64<scalar_t, 5, torch::RestrictPtrTraits>(),
-                    hessdiag.packed_accessor64<scalar_t, 5, torch::RestrictPtrTraits>());
-    });
-    CUDA_CHECK_ERRORS;
-    return std::template tuple<torch::Tensor, torch::Tensor, torch::Tensor>(result, grad, hessdiag);
-}
-
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> se_grad_persp(
-                            TreeSpec& tree,
-                            CameraSpec& cam,
-                            RenderOptions& opt,
-                            torch::Tensor color) {
-    tree.check();
-    cam.check();
-    DEVICE_GUARD(tree.data);
-    CHECK_INPUT(color);
-    const size_t Q = size_t(cam.width) * cam.height;
-
-    auto_cuda_threads();
-    const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
-    int out_data_dim = get_out_data_dim(opt.format, opt.basis_dim, tree.data.size(4));
-    if (out_data_dim > 4) {
-        throw std::runtime_error("Tree's output dim cannot be > 4 for se_grad");
-    }
-    torch::Tensor result = torch::zeros({cam.height, cam.width, out_data_dim},
-            tree.data.options());
-    torch::Tensor grad = torch::zeros_like(tree.data);
-    torch::Tensor hessdiag = torch::zeros_like(tree.data);
-
-    AT_DISPATCH_FLOATING_TYPES(tree.data.type(), __FUNCTION__, [&] {
-            device::se_grad_persp_kernel<scalar_t><<<blocks, cuda_n_threads>>>(
-                    tree, cam, opt,
-                    color.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
-                    result.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
-                    grad.packed_accessor64<scalar_t, 5, torch::RestrictPtrTraits>(),
-                    hessdiag.packed_accessor64<scalar_t, 5, torch::RestrictPtrTraits>());
-    });
-    CUDA_CHECK_ERRORS;
-    return std::template tuple<torch::Tensor, torch::Tensor, torch::Tensor>(result, grad, hessdiag);
-}
-std::vector<torch::Tensor> grid_weight_render(
-    torch::Tensor data, CameraSpec& cam, RenderOptions& opt,
-    torch::Tensor offset, torch::Tensor scaling) {
-    cam.check();
-    DEVICE_GUARD(data);
-    const size_t Q = size_t(cam.width) * cam.height;
-
-    auto_cuda_threads();
-    const int blocks = CUDA_N_BLOCKS_NEEDED(Q, cuda_n_threads);
-    torch::Tensor grid_weight = torch::zeros_like(data);
-    torch::Tensor grid_hit = torch::zeros_like(data);
-
-    AT_DISPATCH_FLOATING_TYPES(data.type(), __FUNCTION__, [&] {
-            device::grid_weight_render_kernel<scalar_t><<<blocks, cuda_n_threads>>>(
-                data.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
-                cam,
-                opt,
-                offset.data<scalar_t>(),
-                scaling.data<scalar_t>(),
-                grid_weight.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
-                grid_hit.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>());
-    });
-    CUDA_CHECK_ERRORS;
-    return {grid_weight, grid_hit};
-}
